/*
Copyright 2015 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package portforward

import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"regexp"
	"sort"
	"strconv"
	"strings"
	"sync"

	"go.uber.org/atomic"
	klog "k8s.io/klog/v2"

	corev1 "k8s.io/api/core/v1"
	"k8s.io/apimachinery/pkg/util/httpstream"
	"k8s.io/apimachinery/pkg/util/runtime"
	netutils "k8s.io/utils/net"
)

// PortForwardProtocolV1Name is the subprotocol requested from the dialer when
// setting up port forwarding; ForwardPorts rejects a server that negotiates
// any other protocol.
//
// TODO: move to API machinery and re-unify with kubelet/server/portforward.
const (
	PortForwardProtocolV1Name = "portforward.k8s.io"
)

var (
	// ErrLostConnectionToPod is raised (and, when an error channel was
	// supplied, sent on it) when the stream connection to the pod closes while
	// forwarding is still running.
	ErrLostConnectionToPod = errors.New("lost connection to pod")

	// networkClosedError is the lower-case message fragment of errors that are
	// expected while port forwarding shuts down; such errors are not reported.
	networkClosedError = "use of closed network connection"

	// networkClosedRe matches container-runtime error messages indicating that
	// the pod sandbox is gone.
	//
	// Containerd returns "network namespace for sandbox %q is closed", see:
	// - 2.0: https://github.com/containerd/containerd/blob/v2.0.0-rc.6/internal/cri/server/sandbox_portforward_linux.go#L47
	// - 1.7: https://github.com/containerd/containerd/blob/release/1.7/pkg/cri/server/sandbox_portforward_linux.go#L47
	// CRI-O returns "sandbox %s is not running", see:
	// - 1.31: https://github.com/cri-o/cri-o/blob/release-1.31/server/container_portforward.go#L54
	networkClosedRe = regexp.MustCompile(`network namespace for sandbox "\w+" is closed|sandbox \w+ is not running`)
)

// PortForwarder knows how to listen for local connections and forward them to
// a remote pod via an upgraded HTTP request.
//
// It embeds a sync.Mutex and an atomic counter, so a PortForwarder must not be
// copied after it has been created.
type PortForwarder struct {
	addresses []listenAddress // local addresses to listen on
	ports     []ForwardedPort // Local is rewritten with the bound port
	stopChan  <-chan struct{} // closing it ends forwarding

	dialer         httpstream.Dialer     // opens streamConn in ForwardPorts
	streamConn     httpstream.Connection // carries an error and a data stream per local connection
	listeners      []io.Closer           // every listener opened; closed by Close
	Ready          chan struct{}         // if non-nil, closed once at least one port is listening
	requestIDLock  sync.Mutex            // guards requestID
	requestID      int                   // next ID returned by nextRequestID
	out            io.Writer             // progress messages; may be nil
	errOut         io.Writer             // per-port listen failures; may be nil
	errChan        chan<- error          // if non-nil, receives errors passed to raiseError
	numConnections atomic.Int64          // connections currently inside handleConnection
}

// ForwardedPort contains a Local:Remote port pairing.
type ForwardedPort struct {
	Local  uint16 // local port; 0 requests a random port and is replaced by the bound port
	Remote uint16 // port inside the pod; parsePorts rejects 0
}

// parsePorts converts port specifications into local/remote port pairs.
// Accepted forms:
//
//   - "5000": forwards local port 5000 to pod port 5000.
//   - "8888:5000": forwards local port 8888 to pod port 5000.
//   - "0:5000" or ":5000": forwards a random available local port to pod
//     port 5000.
//
// The remote port must be non-zero. The first malformed entry aborts parsing
// and its error is returned.
func parsePorts(ports []string) ([]ForwardedPort, error) {
	var result []ForwardedPort
	for _, spec := range ports {
		pieces := strings.Split(spec, ":")

		var local, remote string
		switch len(pieces) {
		case 1:
			local, remote = pieces[0], pieces[0]
		case 2:
			local, remote = pieces[0], pieces[1]
			if local == "" {
				// ":5000" is shorthand for "0:5000".
				local = "0"
			}
		default:
			return nil, fmt.Errorf("invalid port format '%s'", spec)
		}

		lp, err := strconv.ParseUint(local, 10, 16)
		if err != nil {
			return nil, fmt.Errorf("error parsing local port '%s': %w", local, err)
		}

		rp, err := strconv.ParseUint(remote, 10, 16)
		if err != nil {
			return nil, fmt.Errorf("error parsing remote port '%s': %w", remote, err)
		}
		if rp == 0 {
			return nil, fmt.Errorf("remote port must be > 0")
		}

		result = append(result, ForwardedPort{Local: uint16(lp), Remote: uint16(rp)})
	}

	return result, nil
}

// listenAddress is a single local address that a PortForwarder listens on.
type listenAddress struct {
	address     string // IP literal to listen on
	protocol    string // net.Listen network: "tcp4" or "tcp6"
	failureMode string // "all" (addresses from "localhost") or "any" (explicit IPs); see listenOnPort
}

// parseAddresses converts listen addresses into listenAddress values sorted by
// address.
//
// "localhost" expands to 127.0.0.1 (tcp4) and ::1 (tcp6), both with
// failureMode "all". Every other entry must parse as an IP and gets
// failureMode "any". Entries are de-duplicated by their exact string, and an
// explicitly listed loopback IP takes precedence over the "localhost"
// expansion regardless of order. An entry that is not an IP yields an error.
func parseAddresses(addressesToParse []string) ([]listenAddress, error) {
	byAddress := make(map[string]listenAddress)
	for _, addr := range addressesToParse {
		if addr == "localhost" {
			loopbacks := []listenAddress{
				{address: "127.0.0.1", protocol: "tcp4", failureMode: "all"},
				{address: "::1", protocol: "tcp6", failureMode: "all"},
			}
			for _, lo := range loopbacks {
				// Never overwrite an address that was listed explicitly.
				if _, seen := byAddress[lo.address]; !seen {
					byAddress[lo.address] = lo
				}
			}
			continue
		}

		ip := netutils.ParseIPSloppy(addr)
		switch {
		case ip == nil:
			return nil, fmt.Errorf("%s is not a valid IP", addr)
		case ip.To4() != nil:
			byAddress[addr] = listenAddress{address: addr, protocol: "tcp4", failureMode: "any"}
		default:
			byAddress[addr] = listenAddress{address: addr, protocol: "tcp6", failureMode: "any"}
		}
	}

	result := make([]listenAddress, 0, len(byAddress))
	for _, la := range byAddress {
		result = append(result, la)
	}
	// Map iteration order is random; sort so callers get a stable order.
	sort.Slice(result, func(i, j int) bool { return result[i].address < result[j].address })

	return result, nil
}

// New creates a PortForwarder that listens on "localhost" (see
// NewOnAddresses for the meaning of the remaining arguments).
func New(dialer httpstream.Dialer, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, errChan chan<- error, out, errOut io.Writer) (*PortForwarder, error) {
	defaultAddresses := []string{"localhost"}
	return NewOnAddresses(dialer, defaultAddresses, ports, stopChan, readyChan, errChan, out, errOut)
}

// NewOnAddresses creates a PortForwarder that listens on the given addresses
// (IP literals or "localhost") and forwards the given port specifications
// (forms accepted by parsePorts). It returns an error if either list is empty
// or contains an invalid entry; the address list is validated first.
// readyChan, errChan, out and errOut may be nil.
func NewOnAddresses(dialer httpstream.Dialer, addresses []string, ports []string, stopChan <-chan struct{}, readyChan chan struct{}, errChan chan<- error, out, errOut io.Writer) (*PortForwarder, error) {
	if len(addresses) == 0 {
		return nil, errors.New("you must specify at least 1 address")
	}
	listenAddrs, err := parseAddresses(addresses)
	if err != nil {
		return nil, err
	}

	if len(ports) == 0 {
		return nil, errors.New("you must specify at least 1 port")
	}
	forwardedPorts, err := parsePorts(ports)
	if err != nil {
		return nil, err
	}

	pf := &PortForwarder{
		addresses: listenAddrs,
		ports:     forwardedPorts,
		stopChan:  stopChan,
		dialer:    dialer,
		Ready:     readyChan,
		errChan:   errChan,
		out:       out,
		errOut:    errOut,
	}
	return pf, nil
}

// raiseError reports err on pf.errChan (when one was configured) without ever
// blocking the caller, then closes pf.streamConn (when set). Because the send
// happens in its own goroutine, that goroutine stays blocked until something
// receives from the channel.
func (pf *PortForwarder) raiseError(err error) {
	if ch := pf.errChan; ch != nil {
		go func() { ch <- err }()
	}

	if conn := pf.streamConn; conn != nil {
		_ = conn.Close()
	}
}

// NumConnections returns the number of local connections currently being
// forwarded (the counter is maintained by handleConnection).
func (pf *PortForwarder) NumConnections() int64 {
	current := pf.numConnections.Load()
	return current
}

// ForwardPorts dials the remote end, checks that the server agreed to
// PortForwardProtocolV1Name, and then forwards the configured ports. It
// returns once forwarding stops (for example when stopChan is closed or the
// stream connection is lost). All listeners and the stream connection are
// closed before it returns.
func (pf *PortForwarder) ForwardPorts(ctx context.Context) error {
	defer pf.Close()

	conn, negotiated, err := pf.dialer.Dial(PortForwardProtocolV1Name)
	pf.streamConn = conn
	if err != nil {
		return fmt.Errorf("error upgrading connection: %w", err)
	}
	defer conn.Close()

	if negotiated != PortForwardProtocolV1Name {
		return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, negotiated)
	}
	return pf.forward(ctx)
}

// forward starts listeners for every configured port, closes pf.Ready (if
// set) once at least one port is listening, and then blocks until stopChan is
// closed, ctx is cancelled, or the stream connection closes. Losing the stream
// connection is reported through raiseError as ErrLostConnectionToPod;
// cancellation of ctx is treated like closing stopChan.
//
// Per-port listen failures are written to pf.errOut (when set). An error is
// returned only if no port could be listened on; it wraps the last failure.
func (pf *PortForwarder) forward(ctx context.Context) error {
	var err error

	listenSuccess := false
	for i := range pf.ports {
		// Take a pointer so the locally bound port (relevant when 0 was
		// requested) is recorded back into pf.ports.
		port := &pf.ports[i]
		if err = pf.listenOnPort(ctx, port); err != nil {
			if pf.errOut != nil {
				fmt.Fprintf(pf.errOut, "Unable to listen on port %d: %v\n", port.Local, err)
			}
			continue
		}
		listenSuccess = true
	}

	if !listenSuccess {
		return fmt.Errorf("unable to listen on any of the requested ports: %v, error: %w", pf.ports, err)
	}

	if pf.Ready != nil {
		close(pf.Ready)
	}

	// A nil channel never becomes ready, so without a stream connection only
	// stopChan or ctx can end the wait.
	var streamConnCloseChan <-chan bool
	if pf.streamConn != nil {
		streamConnCloseChan = pf.streamConn.CloseChan()
	}

	// wait for interrupt, cancellation or conn closure
	select {
	case <-pf.stopChan:
	case <-ctx.Done():
		// Cancellation is a requested shutdown, just like closing stopChan.
	case <-streamConnCloseChan:
		pf.raiseError(ErrLostConnectionToPod)
	}

	return nil
}

// listenOnPort tries to listen on port at every configured address, and each
// listener that succeeds starts accepting connections. Whether failures are
// fatal depends on each address's failureMode:
//   - "all" (the loopback addresses produced from "localhost"): an error is
//     returned only if every such address failed.
//   - "any" (explicitly requested IPs): any failure returns an error.
//
// Listeners that were opened successfully stay open even when an error is
// returned.
func (pf *PortForwarder) listenOnPort(ctx context.Context, port *ForwardedPort) error {
	var listenErrs []error
	failed := make(map[string]int, 2)
	succeeded := make(map[string]int, 2)
	for _, addr := range pf.addresses {
		if err := pf.listenOnPortAndAddress(ctx, port, addr.protocol, addr.address); err != nil {
			listenErrs = append(listenErrs, err)
			failed[addr.failureMode]++
			continue
		}
		succeeded[addr.failureMode]++
	}

	allGroupFailed := failed["all"] > 0 && succeeded["all"] == 0
	if allGroupFailed || failed["any"] > 0 {
		return fmt.Errorf("%s: %v", "Listeners failed to create with the following errors", listenErrs)
	}
	return nil
}

// listenOnPortAndAddress opens a listener for port on address using protocol,
// records it in pf.listeners so that Close can shut it down, and starts a
// goroutine that accepts connections on it.
func (pf *PortForwarder) listenOnPortAndAddress(ctx context.Context, port *ForwardedPort, protocol string, address string) error {
	ln, err := pf.getListener(protocol, address, port)
	if err != nil {
		return err
	}
	pf.listeners = append(pf.listeners, ln)

	// Pass a copy of *port: it already holds the local port that getListener
	// actually bound.
	go pf.waitForConnection(ctx, ln, *port)
	return nil
}

// getListener creates a listener on hostname at port.Local using protocol, a
// net.Listen network name such as "tcp", "tcp4" or "tcp6". The port that was
// actually bound (which differs from the request when port.Local is 0) is
// written back into port.Local. When pf.out is non-nil, a "Forwarding from"
// line is written to it on success and a "Failed to forward" line on failure
// to read back the bound port.
//
// If the bound port cannot be read back, the listener is closed before the
// error is returned, so no socket is leaked.
func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
	listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
	if err != nil {
		return nil, fmt.Errorf("unable to create listener: Error %w", err)
	}

	listenerAddress := listener.Addr().String()
	host, localPort, err := net.SplitHostPort(listenerAddress)
	var localPortUInt uint64
	if err == nil {
		localPortUInt, err = strconv.ParseUint(localPort, 10, 16)
	}
	if err != nil {
		// The listener is not returned on this path, so close it here rather
		// than leaking the bound socket.
		_ = listener.Close()
		if pf.out != nil {
			fmt.Fprintf(pf.out, "Failed to forward from %s -> %d\n", net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))), port.Remote)
		}
		return nil, fmt.Errorf("error parsing local port: %w from %s (%s)", err, listenerAddress, host)
	}

	port.Local = uint16(localPortUInt)
	if pf.out != nil {
		fmt.Fprintf(pf.out, "Forwarding from %s -> %d\n", net.JoinHostPort(hostname, strconv.Itoa(int(localPortUInt))), port.Remote)
	}

	return listener, nil
}

// waitForConnection accepts connections on listener and hands each one to
// handleConnection in its own goroutine. It returns when the listener is
// closed or fails, or when pf.streamConn is observed to be closed between two
// accepts. Accept errors other than those caused by closing the listener are
// reported through raiseError.
func (pf *PortForwarder) waitForConnection(ctx context.Context, listener net.Listener, port ForwardedPort) {
	for {
		// A nil channel is never ready, so without a stream connection this
		// check never fires.
		var closeChan <-chan bool
		if pf.streamConn != nil {
			closeChan = pf.streamConn.CloseChan()
		}
		select {
		case <-closeChan:
			return
		default:
		}

		conn, err := listener.Accept()
		if err != nil {
			// Closing the listener (see Close) is the normal way to stop this
			// loop, so that error is not reported. Prefer the typed
			// net.ErrClosed check and fall back to the message match for
			// listeners that do not wrap it.
			// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
			closed := errors.Is(err, net.ErrClosed) ||
				strings.Contains(strings.ToLower(err.Error()), networkClosedError)
			if !closed {
				pf.raiseError(fmt.Errorf("error accepting connection on port %d: %w", port.Local, err))
			}
			return
		}
		go pf.handleConnection(ctx, conn, port)
	}
}

// nextRequestID returns the current value of pf.requestID and advances it by
// one, so successive calls yield 0, 1, 2, ... for a new PortForwarder. It is
// safe for concurrent use.
func (pf *PortForwarder) nextRequestID() int {
	pf.requestIDLock.Lock()
	next := pf.requestID
	pf.requestID = next + 1
	pf.requestIDLock.Unlock()
	return next
}

// handleConnection forwards one accepted local connection to port.Remote in
// the pod. It opens an error stream and a data stream on pf.streamConn (both
// tagged with the same request ID) and copies data in both directions. It
// returns once the remote->local copy finishes or the local->remote copy
// fails. conn is always closed on return.
//
// An error reported by the remote end on the error stream is only logged,
// except for the cases listed near the end of this function. Those cases raise
// the error, which closes the whole port-forward connection.
func (pf *PortForwarder) handleConnection(ctx context.Context, conn net.Conn, port ForwardedPort) {
	defer conn.Close()

	logger := klog.FromContext(ctx)
	pf.numConnections.Inc()
	defer pf.numConnections.Dec()
	if pf.out != nil {
		fmt.Fprintf(pf.out, "Handling connection for %d\n", port.Local)
	}

	requestID := pf.nextRequestID()

	// create error stream
	headers := http.Header{}
	headers.Set(corev1.StreamType, corev1.StreamTypeError)
	headers.Set(corev1.PortHeader, fmt.Sprintf("%d", port.Remote))
	headers.Set(corev1.PortForwardRequestIDHeader, strconv.Itoa(requestID))

	// errorChan carries at most one error and is then closed. It is buffered
	// so the producing goroutine below can always complete its send and exit,
	// even when this function returns early (e.g. data stream creation fails)
	// and never receives from it. With an unbuffered channel that goroutine
	// would leak.
	errorChan := make(chan error, 1)
	if pf.streamConn != nil {
		errorStream, err := pf.streamConn.CreateStream(headers)
		if err != nil {
			pf.raiseError(fmt.Errorf("error creating error stream for port %d -> %d: %w", port.Local, port.Remote, err))
			return
		}
		// we're not writing to this stream
		errorStream.Close()

		go func() {
			message, err := io.ReadAll(errorStream)
			switch {
			case err != nil:
				errorChan <- fmt.Errorf("error reading from error stream for port %d -> %d: %w", port.Local, port.Remote, err)
			case len(message) > 0:
				errorChan <- fmt.Errorf("an error occurred forwarding %d -> %d: %v", port.Local, port.Remote, string(message))
			}
			close(errorChan)
		}()
	} else {
		go func() {
			errorChan <- errors.New("no streamConn available")
			close(errorChan)
		}()
	}

	// create data stream
	var dataStream httpstream.Stream
	localError := make(chan struct{})
	remoteDone := make(chan struct{})
	if pf.streamConn != nil {
		headers.Set(corev1.StreamType, corev1.StreamTypeData)

		var err error
		dataStream, err = pf.streamConn.CreateStream(headers)
		if err != nil {
			pf.raiseError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %w", port.Local, port.Remote, err))
			return
		}

		go func() {
			// Copy from the remote side to the local port.
			if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(strings.ToLower(err.Error()), networkClosedError) {
				// Guarded by Enabled so these debug errors only appear at V(4).
				if v := logger.V(4); v.Enabled() {
					v.Error(err, "error copying from remote stream to local connection", "debug", true)
				}
			}

			// inform the select below that the remote copy is done
			close(remoteDone)
		}()

		go func() {
			// inform server we're not sending any more data after copy unblocks
			defer dataStream.Close()

			// Copy from the local port to the remote side.
			if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(strings.ToLower(err.Error()), networkClosedError) {
				if v := logger.V(4); v.Enabled() {
					v.Error(err, "error copying from local connection to remote stream", "debug", true)
				}

				// break out of the select below without waiting for the other copy to finish
				close(localError)
			}
		}()
	} else {
		close(localError)
	}

	// wait for either a local->remote error or for copying from remote->local to finish
	select {
	case <-remoteDone:
	case <-localError:
	}

	if dataStream != nil {
		// reset dataStream to discard any unsent data, preventing port forwarding from being blocked.
		// we must reset dataStream before waiting on errorChan, otherwise,
		// the blocking data will affect errorStream and cause <-errorChan to block indefinitely.
		_ = dataStream.Reset()
	}

	// always expect something on errorChan (it may be nil)
	if err := <-errorChan; err != nil {
		// Don't tear down the whole parent port-forward (pf.streamConn) because
		// a single request failed. The container runtime should handle the
		// streaming error by resetting the stream, so the connection can stay
		// open until the user explicitly asks to stop. The exceptions, which
		// raise the error and therefore close pf.streamConn, are:
		//  1. networkClosedRe matches: the runtime reports the pod sandbox or
		//     its network namespace as closed / not running (e.g. the pod was
		//     removed). This is typically only noticed on a subsequent
		//     connection attempt, where one of the io.Copy goroutines fails,
		//     so stop port forwarding.
		//  2. The message mentions "container", e.g. "container not running"
		//     or "no such container".
		if networkClosedRe.MatchString(err.Error()) || strings.Contains(err.Error(), "container") {
			pf.raiseError(err)
		} else {
			logger.Error(err, "Failed handling connection")
		}
	}
}

// Close closes every listener the PortForwarder has opened. Failures are not
// returned; each one is passed to runtime.HandleError.
func (pf *PortForwarder) Close() {
	for i := range pf.listeners {
		err := pf.listeners[i].Close()
		if err == nil {
			continue
		}
		runtime.HandleError(fmt.Errorf("error closing listener: %w", err))
	}
}

// GetPorts returns the forwarded ports. Once the listeners are up (signalled
// by the Ready channel being closed), each Local field holds the port that was
// actually bound. This is how callers discover the port chosen when 0 was
// requested. An error is returned if no Ready channel was provided, or if
// Ready has not been closed yet.
func (pf *PortForwarder) GetPorts() ([]ForwardedPort, error) {
	ready := pf.Ready
	if ready == nil {
		return nil, fmt.Errorf("no Ready channel provided")
	}

	select {
	case <-ready:
	default:
		return nil, fmt.Errorf("listeners not ready")
	}
	return pf.ports, nil
}
